# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
Test :mod:`astropy.io.registry`.

.. todo::

    Don't rely on Table for tests

"""

import os
import sys
from collections import Counter
from copy import deepcopy
from io import StringIO

import numpy as np
import pytest

import astropy.units as u
from astropy.io import registry as io_registry
from astropy.io.registry import (
    IORegistryError,
    UnifiedInputRegistry,
    UnifiedIORegistry,
    UnifiedOutputRegistry,
    compat,
)
from astropy.io.registry.base import _UnifiedIORegistryBase
from astropy.io.registry.compat import default_registry
from astropy.table import Table

SKIPIF_OPTIMIZED_PYTHON = pytest.mark.skipif(
    sys.flags.optimize >= 2, reason="docstrings are not available at runtime"
)
ONLY_OPTIMIZED_PYTHON = pytest.mark.skipif(
    sys.flags.optimize < 2,
    reason="checking behavior specifically if docstrings are not present",
)


###############################################################################
# pytest setup and fixtures


class UnifiedIORegistryBaseSubClass(_UnifiedIORegistryBase):
    """Non-abstract subclass of UnifiedIORegistryBase for testing."""

    def get_formats(self, data_class=None):
        return None


class EmptyData:
    """
    Thing that can read and write.
    Note that the read/write are the compatibility methods, which allow for the
    kwarg ``registry``. This allows us to not subclass ``EmptyData`` for each
    of the types of registry (read-only, ...) and use this class everywhere.
    """

    read = classmethod(io_registry.read)
    write = io_registry.write


class OtherEmptyData:
    """A different class with different I/O"""

    read = classmethod(io_registry.read)
    write = io_registry.write


def empty_reader(*args, **kwargs):
    return EmptyData()


def empty_writer(table, *args, **kwargs):
    return "status: success"


def empty_identifier(*args, **kwargs):
    return True


@pytest.fixture
def fmtcls1():
    return ("test1", EmptyData)


@pytest.fixture
def fmtcls2():
    return ("test2", EmptyData)


@pytest.fixture(params=["test1", "test2"])
def fmtcls(request):
    return (request.param, EmptyData)


@pytest.fixture
def original():
    ORIGINAL = {}
    ORIGINAL["readers"] = deepcopy(default_registry._readers)
    ORIGINAL["writers"] = deepcopy(default_registry._writers)
    ORIGINAL["identifiers"] = deepcopy(default_registry._identifiers)
    return ORIGINAL


###############################################################################


def test_fmcls1_fmtcls2(fmtcls1, fmtcls2):
    """Just check a fact that we rely on in other tests."""
    assert fmtcls1[1] is fmtcls2[1]


def test_IORegistryError():
    with pytest.raises(IORegistryError, match="just checking"):
        raise IORegistryError("just checking")


class TestUnifiedIORegistryBase:
    """Test :class:`astropy.io.registry.UnifiedIORegistryBase`."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        self._cls = UnifiedIORegistryBaseSubClass

    @pytest.fixture
    def registry(self):
        """I/O registry. Cleaned before and after each function."""
        registry = self._cls()

        HAS_READERS = hasattr(registry, "_readers")
        HAS_WRITERS = hasattr(registry, "_writers")

        # copy and clear original registry
        ORIGINAL = {}
        ORIGINAL["identifiers"] = deepcopy(registry._identifiers)
        registry._identifiers.clear()
        if HAS_READERS:
            ORIGINAL["readers"] = deepcopy(registry._readers)
            registry._readers.clear()
        if HAS_WRITERS:
            ORIGINAL["writers"] = deepcopy(registry._writers)
            registry._writers.clear()

        yield registry

        registry._identifiers.clear()
        registry._identifiers.update(ORIGINAL["identifiers"])
        if HAS_READERS:
            registry._readers.clear()
            registry._readers.update(ORIGINAL["readers"])
        if HAS_WRITERS:
            registry._writers.clear()
            registry._writers.update(ORIGINAL["writers"])

    # ===========================================

    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        # defaults
        assert registry.get_formats() is None
        # (kw)args don't matter
        assert registry.get_formats(data_class=24) is None

    @SKIPIF_OPTIMIZED_PYTHON
    def test_delay_doc_updates(self, registry, fmtcls1):
        """Test ``registry.delay_doc_updates()``."""
        # TODO! figure out what can be tested
        with registry.delay_doc_updates(EmptyData):
            registry.register_identifier(*fmtcls1, empty_identifier)

    @ONLY_OPTIMIZED_PYTHON
    def test_compat_delay_doc_updates_optimized_mode(self, registry, fmtcls1):
        # check that the context manager doesn't raise an exception
        # but otherwise it should be a pure no-op
        with registry.delay_doc_updates(EmptyData):
            pass

    def test_register_identifier(self, registry, fmtcls1, fmtcls2):
        """Test ``registry.register_identifier()``."""
        # initial check it's not registered
        assert fmtcls1 not in registry._identifiers
        assert fmtcls2 not in registry._identifiers

        # register
        registry.register_identifier(*fmtcls1, empty_identifier)
        registry.register_identifier(*fmtcls2, empty_identifier)
        assert fmtcls1 in registry._identifiers
        assert fmtcls2 in registry._identifiers

    def test_register_identifier_invalid(self, registry, fmtcls):
        """Test calling ``registry.register_identifier()`` twice."""
        fmt, cls = fmtcls
        registry.register_identifier(fmt, cls, empty_identifier)
        with pytest.raises(IORegistryError) as exc:
            registry.register_identifier(fmt, cls, empty_identifier)
        assert (
            str(exc.value) == f"Identifier for format '{fmt}' and class "
            f"'{cls.__name__}' is already defined"
        )

    def test_register_identifier_force(self, registry, fmtcls1):
        registry.register_identifier(*fmtcls1, empty_identifier)
        registry.register_identifier(*fmtcls1, empty_identifier, force=True)
        assert fmtcls1 in registry._identifiers

    # -----------------------

    def test_unregister_identifier(self, registry, fmtcls1):
        """Test ``registry.unregister_identifier()``."""
        registry.register_identifier(*fmtcls1, empty_identifier)
        assert fmtcls1 in registry._identifiers

        registry.unregister_identifier(*fmtcls1)
        assert fmtcls1 not in registry._identifiers

    def test_unregister_identifier_invalid(self, registry, fmtcls):
        """Test ``registry.unregister_identifier()``."""
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError) as exc:
            registry.unregister_identifier(fmt, cls)
        assert (
            str(exc.value)
            == f"No identifier defined for format '{fmt}' and class '{cls.__name__}'"
        )

    def test_identify_format(self, registry, fmtcls1):
        """Test ``registry.identify_format()``."""
        fmt, cls = fmtcls1
        args = (None, cls, None, None, (None,), {})

        # test no formats to identify
        formats = registry.identify_format(*args)
        assert formats == []

        # test there is a format to identify
        registry.register_identifier(fmt, cls, empty_identifier)
        formats = registry.identify_format(*args)
        assert fmt in formats

    # ===========================================
    # Compat tests

    def test_compat_register_identifier(self, registry, fmtcls1):
        # with registry specified
        assert fmtcls1 not in registry._identifiers
        compat.register_identifier(*fmtcls1, empty_identifier, registry=registry)
        assert fmtcls1 in registry._identifiers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._identifiers
            try:
                compat.register_identifier(*fmtcls1, empty_identifier)
            except Exception:
                pass
            else:
                assert fmtcls1 in default_registry._identifiers
            finally:
                default_registry._identifiers.pop(fmtcls1)

    def test_compat_unregister_identifier(self, registry, fmtcls1):
        # with registry specified
        registry.register_identifier(*fmtcls1, empty_identifier)
        assert fmtcls1 in registry._identifiers
        compat.unregister_identifier(*fmtcls1, registry=registry)
        assert fmtcls1 not in registry._identifiers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._identifiers
            default_registry.register_identifier(*fmtcls1, empty_identifier)
            assert fmtcls1 in default_registry._identifiers
            compat.unregister_identifier(*fmtcls1)
            assert fmtcls1 not in registry._identifiers

    def test_compat_identify_format(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        args = (None, cls, None, None, (None,), {})

        # with registry specified
        registry.register_identifier(*fmtcls1, empty_identifier)
        formats = compat.identify_format(*args, registry=registry)
        assert fmt in formats

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            try:
                default_registry.register_identifier(*fmtcls1, empty_identifier)
            except Exception:
                pass
            else:
                formats = compat.identify_format(*args)
                assert fmt in formats
            finally:
                default_registry.unregister_identifier(*fmtcls1)

    @pytest.mark.skip("TODO!")
    def test_compat_get_formats(self, registry, fmtcls1):
        raise AssertionError()

    @pytest.mark.skip("TODO!")
    def test_compat_delay_doc_updates(self, registry, fmtcls1):
        raise AssertionError()


class TestUnifiedInputRegistry(TestUnifiedIORegistryBase):
    """Test :class:`astropy.io.registry.UnifiedInputRegistry`."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        self._cls = UnifiedInputRegistry

    # ===========================================

    def test_inherited_read_registration(self, registry):
        # check that multi-generation inheritance works properly,
        # meaning that a child inherits from parents before
        # grandparents, see astropy/astropy#7156

        class Child1(EmptyData):
            pass

        class Child2(Child1):
            pass

        def _read():
            return EmptyData()

        def _read1():
            return Child1()

        # check that reader gets inherited
        registry.register_reader("test", EmptyData, _read)
        assert registry.get_reader("test", Child2) is _read

        # check that nearest ancestor is identified
        # (i.e. that the reader for Child2 is the registered method
        #  for Child1, and not Table)
        registry.register_reader("test", Child1, _read1)
        assert registry.get_reader("test", Child2) is _read1

    # ===========================================

    @pytest.mark.skip("TODO!")
    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        raise AssertionError()

    @SKIPIF_OPTIMIZED_PYTHON
    def test_delay_doc_updates(self, registry, fmtcls1):
        """Test ``registry.delay_doc_updates()``."""
        super().test_delay_doc_updates(registry, fmtcls1)

        with registry.delay_doc_updates(EmptyData):
            registry.register_reader("test", EmptyData, empty_reader)

            # test that the doc has not yet been updated.
            # if a the format was registered in a different way, then
            # test that this method is not present.
            if "Format" in EmptyData.read.__doc__:
                docs = EmptyData.read.__doc__.split("\n")
                ihd = [i for i, s in enumerate(docs) if ("Format" in s)][0]
                ifmt = docs[ihd].index("Format") + 1
                iread = docs[ihd].index("Read") + 1
                # there might not actually be anything here, which is also good
                if docs[-2] != docs[-1]:
                    assert docs[-1][ifmt : ifmt + 5] == "test"
                    assert docs[-1][iread : iread + 3] != "Yes"
        # now test it's updated
        docs = EmptyData.read.__doc__.split("\n")
        ihd = [i for i, s in enumerate(docs) if ("Format" in s)][0]
        ifmt = docs[ihd].index("Format") + 2
        iread = docs[ihd].index("Read") + 1
        assert docs[-2][ifmt : ifmt + 4] == "test"
        assert docs[-2][iread : iread + 3] == "Yes"

    def test_identify_read_format(self, registry):
        """Test ``registry.identify_format()``."""
        args = ("read", EmptyData, None, None, (None,), {})

        # test there is no format to identify
        formats = registry.identify_format(*args)
        assert formats == []

        # test there is a format to identify
        # doesn't actually matter if register a reader, it returns True for all
        registry.register_identifier("test", EmptyData, empty_identifier)
        formats = registry.identify_format(*args)
        assert "test" in formats

    # -----------------------

    def test_register_reader(self, registry, fmtcls1, fmtcls2):
        """Test ``registry.register_reader()``."""
        # initial check it's not registered
        assert fmtcls1 not in registry._readers
        assert fmtcls2 not in registry._readers

        # register
        registry.register_reader(*fmtcls1, empty_reader)
        registry.register_reader(*fmtcls2, empty_reader)
        assert fmtcls1 in registry._readers
        assert fmtcls2 in registry._readers
        assert registry._readers[fmtcls1] == (empty_reader, 0)  # (f, priority)
        assert registry._readers[fmtcls2] == (empty_reader, 0)  # (f, priority)

    def test_register_reader_invalid(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        registry.register_reader(fmt, cls, empty_reader)
        with pytest.raises(IORegistryError) as exc:
            registry.register_reader(fmt, cls, empty_reader)
        assert (
            str(exc.value) == f"Reader for format '{fmt}' and class "
            f"'{cls.__name__}' is already defined"
        )

    def test_register_reader_force(self, registry, fmtcls1):
        registry.register_reader(*fmtcls1, empty_reader)
        registry.register_reader(*fmtcls1, empty_reader, force=True)
        assert fmtcls1 in registry._readers

    def test_register_readers_with_same_name_on_different_classes(self, registry):
        # No errors should be generated if the same name is registered for
        # different objects...but this failed under python3
        registry.register_reader("test", EmptyData, lambda: EmptyData())
        registry.register_reader("test", OtherEmptyData, lambda: OtherEmptyData())
        t = EmptyData.read(format="test", registry=registry)
        assert isinstance(t, EmptyData)
        tbl = OtherEmptyData.read(format="test", registry=registry)
        assert isinstance(tbl, OtherEmptyData)

    # -----------------------

    def test_unregister_reader(self, registry, fmtcls1):
        """Test ``registry.unregister_reader()``."""
        registry.register_reader(*fmtcls1, empty_reader)
        assert fmtcls1 in registry._readers

        registry.unregister_reader(*fmtcls1)
        assert fmtcls1 not in registry._readers

    def test_unregister_reader_invalid(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        with pytest.raises(IORegistryError) as exc:
            registry.unregister_reader(*fmtcls1)
        assert (
            str(exc.value)
            == f"No reader defined for format '{fmt}' and class '{cls.__name__}'"
        )

    # -----------------------

    def test_get_reader(self, registry, fmtcls):
        """Test ``registry.get_reader()``."""
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError):
            registry.get_reader(fmt, cls)

        registry.register_reader(fmt, cls, empty_reader)
        reader = registry.get_reader(fmt, cls)
        assert reader is empty_reader

    def test_get_reader_invalid(self, registry, fmtcls):
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError) as exc:
            registry.get_reader(fmt, cls)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt}' and class '{cls.__name__}'"
        )

    # -----------------------

    def test_read_noformat(self, registry, fmtcls1):
        """Test ``registry.read()`` when there isn't a reader."""
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1].read(registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_read_noformat_arbitrary(self, registry, original, fmtcls1):
        """Test that all identifier functions can accept arbitrary input"""
        registry._identifiers.update(original["identifiers"])
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1].read(object(), registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_read_noformat_arbitrary_file(self, tmp_path, registry, original):
        """Tests that all identifier functions can accept arbitrary files"""
        registry._readers.update(original["readers"])
        testfile = tmp_path / "foo.example"
        with open(testfile, "w") as f:
            f.write("Hello world")

        with pytest.raises(IORegistryError) as exc:
            Table.read(testfile)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_read_toomanyformats(self, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2

        registry.register_identifier(fmt1, cls, lambda o, *x, **y: True)
        registry.register_identifier(fmt2, cls, lambda o, *x, **y: True)
        with pytest.raises(IORegistryError) as exc:
            cls.read(registry=registry)
        assert str(exc.value) == f"Format is ambiguous - options are: {fmt1}, {fmt2}"

    def test_read_uses_priority(self, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2
        counter = Counter()

        def counting_reader1(*args, **kwargs):
            counter[fmt1] += 1
            return cls()

        def counting_reader2(*args, **kwargs):
            counter[fmt2] += 1
            return cls()

        registry.register_reader(fmt1, cls, counting_reader1, priority=1)
        registry.register_reader(fmt2, cls, counting_reader2, priority=2)
        registry.register_identifier(fmt1, cls, lambda o, *x, **y: True)
        registry.register_identifier(fmt2, cls, lambda o, *x, **y: True)

        cls.read(registry=registry)
        assert counter[fmt2] == 1
        assert counter[fmt1] == 0

    def test_read_format_noreader(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        with pytest.raises(IORegistryError) as exc:
            cls.read(format=fmt, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt}' and class '{cls.__name__}'"
        )

    def test_read_identifier(self, tmp_path, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2

        registry.register_identifier(
            fmt1, cls, lambda o, path, fileobj, *x, **y: path.endswith("a")
        )
        registry.register_identifier(
            fmt2, cls, lambda o, path, fileobj, *x, **y: path.endswith("b")
        )

        # Now check that we got past the identifier and are trying to get
        # the reader. The registry.get_reader will fail but the error message
        # will tell us if the identifier worked.

        filename = tmp_path / "testfile.a"
        open(filename, "w").close()
        with pytest.raises(IORegistryError) as exc:
            cls.read(filename, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt1}' and class '{cls.__name__}'"
        )

        filename = tmp_path / "testfile.b"
        open(filename, "w").close()
        with pytest.raises(IORegistryError) as exc:
            cls.read(filename, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt2}' and class '{cls.__name__}'"
        )

    def test_read_valid_return(self, registry, fmtcls):
        fmt, cls = fmtcls
        registry.register_reader(fmt, cls, empty_reader)
        t = cls.read(format=fmt, registry=registry)
        assert isinstance(t, cls)

    def test_read_non_existing_unknown_ext(self, fmtcls1):
        """Raise the correct error when attempting to read a non-existing
        file with an unknown extension."""
        with pytest.raises(OSError):
            data = fmtcls1[1].read("non-existing-file-with-unknown.ext")

    def test_read_directory(self, tmp_path, registry, fmtcls1):
        """
        Regression test for a bug that caused the I/O registry infrastructure to
        not work correctly for datasets that are represented by folders as
        opposed to files, when using the descriptors to add read/write methods.
        """
        _, cls = fmtcls1
        registry.register_identifier(
            "test_folder_format", cls, lambda o, *x, **y: o == "read"
        )
        registry.register_reader("test_folder_format", cls, empty_reader)

        filename = tmp_path / "folder_dataset"
        filename.mkdir()

        # With the format explicitly specified
        dataset = cls.read(filename, format="test_folder_format", registry=registry)
        assert isinstance(dataset, cls)

        # With the auto-format identification
        dataset = cls.read(filename, registry=registry)
        assert isinstance(dataset, cls)

    # ===========================================
    # Compat tests

    def test_compat_register_reader(self, registry, fmtcls1):
        # with registry specified
        assert fmtcls1 not in registry._readers
        compat.register_reader(*fmtcls1, empty_reader, registry=registry)
        assert fmtcls1 in registry._readers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._readers
            try:
                compat.register_reader(*fmtcls1, empty_identifier)
            except Exception:
                pass
            else:
                assert fmtcls1 in default_registry._readers
            finally:
                default_registry._readers.pop(fmtcls1)

    def test_compat_unregister_reader(self, registry, fmtcls1):
        # with registry specified
        registry.register_reader(*fmtcls1, empty_reader)
        assert fmtcls1 in registry._readers
        compat.unregister_reader(*fmtcls1, registry=registry)
        assert fmtcls1 not in registry._readers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._readers
            default_registry.register_reader(*fmtcls1, empty_reader)
            assert fmtcls1 in default_registry._readers
            compat.unregister_reader(*fmtcls1)
            assert fmtcls1 not in registry._readers

    def test_compat_get_reader(self, registry, fmtcls1):
        # with registry specified
        registry.register_reader(*fmtcls1, empty_reader)
        reader = compat.get_reader(*fmtcls1, registry=registry)
        assert reader is empty_reader
        registry.unregister_reader(*fmtcls1)

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            default_registry.register_reader(*fmtcls1, empty_reader)
            reader = compat.get_reader(*fmtcls1)
            assert reader is empty_reader
            default_registry.unregister_reader(*fmtcls1)

    def test_compat_read(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        # with registry specified
        registry.register_reader(*fmtcls1, empty_reader)
        t = compat.read(cls, format=fmt, registry=registry)
        assert isinstance(t, cls)
        registry.unregister_reader(*fmtcls1)

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            default_registry.register_reader(*fmtcls1, empty_reader)
            t = compat.read(cls, format=fmt)
            assert isinstance(t, cls)
            default_registry.unregister_reader(*fmtcls1)


class TestUnifiedOutputRegistry(TestUnifiedIORegistryBase):
    """Test :class:`astropy.io.registry.UnifiedOutputRegistry`."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        self._cls = UnifiedOutputRegistry

    # ===========================================

    def test_inherited_write_registration(self, registry):
        # check that multi-generation inheritance works properly,
        # meaning that a child inherits from parents before
        # grandparents, see astropy/astropy#7156

        class Child1(EmptyData):
            pass

        class Child2(Child1):
            pass

        def _write():
            return EmptyData()

        def _write1():
            return Child1()

        # check that writer gets inherited
        registry.register_writer("test", EmptyData, _write)
        assert registry.get_writer("test", Child2) is _write

        # check that nearest ancestor is identified
        # (i.e. that the writer for Child2 is the registered method
        #  for Child1, and not Table)
        registry.register_writer("test", Child1, _write1)
        assert registry.get_writer("test", Child2) is _write1

    # ===========================================

    @SKIPIF_OPTIMIZED_PYTHON
    def test_delay_doc_updates(self, registry, fmtcls1):
        """Test ``registry.delay_doc_updates()``."""
        super().test_delay_doc_updates(registry, fmtcls1)
        fmt, cls = fmtcls1

        with registry.delay_doc_updates(EmptyData):
            registry.register_writer(*fmtcls1, empty_writer)

            # test that the doc has not yet been updated.
            # if a the format was registered in a different way, then
            # test that this method is not present.
            if "Format" in EmptyData.read.__doc__:
                docs = EmptyData.write.__doc__.split("\n")
                ihd = [i for i, s in enumerate(docs) if ("Format" in s)][0]
                ifmt = docs[ihd].index("Format")
                iwrite = docs[ihd].index("Write") + 1
                # there might not actually be anything here, which is also good
                if docs[-2] != docs[-1]:
                    assert fmt in docs[-1][ifmt : ifmt + len(fmt) + 1]
                    assert docs[-1][iwrite : iwrite + 3] != "Yes"
        # now test it's updated
        docs = EmptyData.write.__doc__.split("\n")
        ihd = [i for i, s in enumerate(docs) if ("Format" in s)][0]
        ifmt = docs[ihd].index("Format") + 1
        iwrite = docs[ihd].index("Write") + 2
        assert fmt in docs[-2][ifmt : ifmt + len(fmt) + 1]
        assert docs[-2][iwrite : iwrite + 3] == "Yes"

    @pytest.mark.skip("TODO!")
    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        raise AssertionError()

    def test_identify_write_format(self, registry, fmtcls1):
        """Test ``registry.identify_format()``."""
        fmt, cls = fmtcls1
        args = ("write", cls, None, None, (None,), {})

        # test there is no format to identify
        formats = registry.identify_format(*args)
        assert formats == []

        # test there is a format to identify
        # doesn't actually matter if register a writer, it returns True for all
        registry.register_identifier(fmt, cls, empty_identifier)
        formats = registry.identify_format(*args)
        assert fmt in formats

    # -----------------------

    def test_register_writer(self, registry, fmtcls1, fmtcls2):
        """Test ``registry.register_writer()``."""
        # initial check it's not registered
        assert fmtcls1 not in registry._writers
        assert fmtcls2 not in registry._writers

        # register
        registry.register_writer(*fmtcls1, empty_writer)
        registry.register_writer(*fmtcls2, empty_writer)
        assert fmtcls1 in registry._writers
        assert fmtcls2 in registry._writers

    def test_register_writer_invalid(self, registry, fmtcls):
        """Test calling ``registry.register_writer()`` twice."""
        fmt, cls = fmtcls
        registry.register_writer(fmt, cls, empty_writer)
        with pytest.raises(IORegistryError) as exc:
            registry.register_writer(fmt, cls, empty_writer)
        assert (
            str(exc.value) == f"Writer for format '{fmt}' and class "
            f"'{cls.__name__}' is already defined"
        )

    def test_register_writer_force(self, registry, fmtcls1):
        registry.register_writer(*fmtcls1, empty_writer)
        registry.register_writer(*fmtcls1, empty_writer, force=True)
        assert fmtcls1 in registry._writers

    # -----------------------

    def test_unregister_writer(self, registry, fmtcls1):
        """Test ``registry.unregister_writer()``."""
        registry.register_writer(*fmtcls1, empty_writer)
        assert fmtcls1 in registry._writers

        registry.unregister_writer(*fmtcls1)
        assert fmtcls1 not in registry._writers

    def test_unregister_writer_invalid(self, registry, fmtcls):
        """Test ``registry.unregister_writer()``."""
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError) as exc:
            registry.unregister_writer(fmt, cls)
        assert (
            str(exc.value)
            == f"No writer defined for format '{fmt}' and class '{cls.__name__}'"
        )

    # -----------------------

    def test_get_writer(self, registry, fmtcls1):
        """Test ``registry.get_writer()``."""
        with pytest.raises(IORegistryError):
            registry.get_writer(*fmtcls1)

        registry.register_writer(*fmtcls1, empty_writer)
        writer = registry.get_writer(*fmtcls1)
        assert writer is empty_writer

    def test_get_writer_invalid(self, registry, fmtcls1):
        """Test invalid ``registry.get_writer()``."""
        fmt, cls = fmtcls1
        with pytest.raises(IORegistryError) as exc:
            registry.get_writer(fmt, cls)
        assert str(exc.value).startswith(
            f"No writer defined for format '{fmt}' and class '{cls.__name__}'"
        )

    # -----------------------

    def test_write_noformat(self, registry, fmtcls1):
        """Test ``registry.write()`` when there isn't a writer."""
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1]().write(registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_write_noformat_arbitrary(self, registry, original, fmtcls1):
        """Test that all identifier functions can accept arbitrary input"""

        registry._identifiers.update(original["identifiers"])
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1]().write(object(), registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_write_noformat_arbitrary_file(self, tmp_path, registry, original):
        """Tests that all identifier functions can accept arbitrary files"""
        registry._writers.update(original["writers"])
        testfile = tmp_path / "foo.example"

        with pytest.raises(IORegistryError) as exc:
            Table().write(testfile, registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_write_toomanyformats(self, registry, fmtcls1, fmtcls2):
        registry.register_identifier(*fmtcls1, lambda o, *x, **y: True)
        registry.register_identifier(*fmtcls2, lambda o, *x, **y: True)
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1]().write(registry=registry)
        assert (
            str(exc.value)
            == f"Format is ambiguous - options are: {fmtcls1[0]}, {fmtcls2[0]}"
        )

    def test_write_uses_priority(self, registry, fmtcls1, fmtcls2):
        fmt1, cls1 = fmtcls1
        fmt2, cls2 = fmtcls2
        counter = Counter()

        def counting_writer1(*args, **kwargs):
            counter[fmt1] += 1

        def counting_writer2(*args, **kwargs):
            counter[fmt2] += 1

        registry.register_writer(fmt1, cls1, counting_writer1, priority=1)
        registry.register_writer(fmt2, cls2, counting_writer2, priority=2)
        registry.register_identifier(fmt1, cls1, lambda o, *x, **y: True)
        registry.register_identifier(fmt2, cls2, lambda o, *x, **y: True)

        cls1().write(registry=registry)
        assert counter[fmt2] == 1
        assert counter[fmt1] == 0

    def test_write_format_nowriter(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        with pytest.raises(IORegistryError) as exc:
            cls().write(format=fmt, registry=registry)
        assert str(exc.value).startswith(
            f"No writer defined for format '{fmt}' and class '{cls.__name__}'"
        )

    def test_write_identifier(self, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2

        registry.register_identifier(fmt1, cls, lambda o, *x, **y: x[0].startswith("a"))
        registry.register_identifier(fmt2, cls, lambda o, *x, **y: x[0].startswith("b"))

        # Now check that we got past the identifier and are trying to get
        # the reader. The registry.get_writer will fail but the error message
        # will tell us if the identifier worked.
        with pytest.raises(IORegistryError) as exc:
            cls().write("abc", registry=registry)
        assert str(exc.value).startswith(
            f"No writer defined for format '{fmt1}' and class '{cls.__name__}'"
        )

        with pytest.raises(IORegistryError) as exc:
            cls().write("bac", registry=registry)
        assert str(exc.value).startswith(
            f"No writer defined for format '{fmt2}' and class '{cls.__name__}'"
        )

    def test_write_return(self, registry, fmtcls1):
        """Most writers will return None, but other values are not forbidden."""
        fmt, cls = fmtcls1
        registry.register_writer(fmt, cls, empty_writer)
        res = cls.write(cls(), format=fmt, registry=registry)
        assert res == "status: success"

    # ===========================================
    # Compat tests

    def test_compat_register_writer(self, registry, fmtcls1):
        # with registry specified
        assert fmtcls1 not in registry._writers
        compat.register_writer(*fmtcls1, empty_writer, registry=registry)
        assert fmtcls1 in registry._writers
        registry.unregister_writer(*fmtcls1)

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._writers
            try:
                compat.register_writer(*fmtcls1, empty_writer)
            except Exception:
                pass
            else:
                assert fmtcls1 in default_registry._writers
            finally:
                default_registry._writers.pop(fmtcls1)

    def test_compat_unregister_writer(self, registry, fmtcls1):
        # with registry specified
        registry.register_writer(*fmtcls1, empty_writer)
        assert fmtcls1 in registry._writers
        compat.unregister_writer(*fmtcls1, registry=registry)
        assert fmtcls1 not in registry._writers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._writers
            default_registry.register_writer(*fmtcls1, empty_writer)
            assert fmtcls1 in default_registry._writers
            compat.unregister_writer(*fmtcls1)
            assert fmtcls1 not in default_registry._writers

    def test_compat_get_writer(self, registry, fmtcls1):
        # with registry specified
        registry.register_writer(*fmtcls1, empty_writer)
        writer = compat.get_writer(*fmtcls1, registry=registry)
        assert writer is empty_writer

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._writers
            default_registry.register_writer(*fmtcls1, empty_writer)
            assert fmtcls1 in default_registry._writers
            writer = compat.get_writer(*fmtcls1)
            assert writer is empty_writer
            default_registry.unregister_writer(*fmtcls1)
            assert fmtcls1 not in default_registry._writers

    def test_compat_write(self, registry, fmtcls1):
        fmt, cls = fmtcls1

        # with registry specified
        registry.register_writer(*fmtcls1, empty_writer)
        res = compat.write(cls(), format=fmt, registry=registry)
        assert res == "status: success"

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._writers
            default_registry.register_writer(*fmtcls1, empty_writer)
            assert fmtcls1 in default_registry._writers
            res = compat.write(cls(), format=fmt)
            assert res == "status: success"
            default_registry.unregister_writer(*fmtcls1)
            assert fmtcls1 not in default_registry._writers


class TestUnifiedIORegistry(TestUnifiedInputRegistry, TestUnifiedOutputRegistry):
    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        self._cls = UnifiedIORegistry

    # ===========================================

    @pytest.mark.skip("TODO!")
    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        raise AssertionError()

    # -----------------------

    def test_identifier_origin(self, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2

        registry.register_identifier(fmt1, cls, lambda o, *x, **y: o == "read")
        registry.register_identifier(fmt2, cls, lambda o, *x, **y: o == "write")
        registry.register_reader(fmt1, cls, empty_reader)
        registry.register_writer(fmt2, cls, empty_writer)

        # There should not be too many formats defined
        cls.read(registry=registry)
        cls().write(registry=registry)

        with pytest.raises(IORegistryError) as exc:
            cls.read(format=fmt2, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt2}' and class '{cls.__name__}'"
        )

        with pytest.raises(IORegistryError) as exc:
            cls().write(format=fmt1, registry=registry)
        assert str(exc.value).startswith(
            f"No writer defined for format '{fmt1}' and class '{cls.__name__}'"
        )


class TestDefaultRegistry(TestUnifiedIORegistry):
    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        self._cls = lambda *args: default_registry


# =============================================================================
# Test compat
# much of this is already tested above since EmptyData uses io_registry.X(),
# which are the compat methods.


def test_dir():
    """Test all the compat methods are in the directory"""
    dc = dir(compat)
    for n in compat.__all__:
        assert n in dc


def test_getattr():
    for n in compat.__all__:
        assert hasattr(compat, n)

    with pytest.raises(AttributeError, match="module 'astropy.io.registry.compat'"):
        compat.this_is_definitely_not_in_this_module


# =============================================================================
# Table tests


def test_read_basic_table():
    registry = Table.read._registry
    data = np.array(
        list(zip([1, 2, 3], ["a", "b", "c"])), dtype=[("A", int), ("B", "|U1")]
    )
    try:
        registry.register_reader("test", Table, lambda x: Table(x))
    except Exception:
        pass
    else:
        t = Table.read(data, format="test")
        assert t.keys() == ["A", "B"]
        for i in range(3):
            assert t["A"][i] == data["A"][i]
            assert t["B"][i] == data["B"][i]
    finally:
        registry._readers.pop("test", None)


class TestSubclass:
    """
    Test using registry with a Table sub-class
    """

    @pytest.fixture(autouse=True)
    def registry(self):
        """I/O registry. Not cleaned."""
        return

    def test_read_table_subclass(self):
        class MyTable(Table):
            pass

        data = ["a b", "1 2"]
        mt = MyTable.read(data, format="ascii")
        t = Table.read(data, format="ascii")
        assert np.all(mt == t)
        assert mt.colnames == t.colnames
        assert type(mt) is MyTable

    def test_write_table_subclass(self):
        buffer = StringIO()

        class MyTable(Table):
            pass

        mt = MyTable([[1], [2]], names=["a", "b"])
        mt.write(buffer, format="ascii")
        assert buffer.getvalue() == os.linesep.join(["a b", "1 2", ""])

    def test_read_table_subclass_with_columns_attributes(self, tmp_path):
        """Regression test for https://github.com/astropy/astropy/issues/7181"""

        class MTable(Table):
            pass

        mt = MTable([[1, 2.5]], names=["a"])
        mt["a"].unit = u.m
        mt["a"].format = ".4f"
        mt["a"].description = "hello"

        testfile = tmp_path / "junk.fits"
        mt.write(testfile, overwrite=True)

        t = MTable.read(testfile)
        assert np.all(mt == t)
        assert mt.colnames == t.colnames
        assert type(t) is MTable
        assert t["a"].unit == u.m
        assert t["a"].format == "{:13.4f}"
        assert t["a"].description == "hello"
